-
Notifications
You must be signed in to change notification settings - Fork 329
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Various masked operations #2428
base: master
Are you sure you want to change the base?
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
@@ -1050,6 +1050,9 @@ types, and on SVE/RVV. | |||
|
|||
* <code>V **AndNot**(V a, V b)</code>: returns `~a[i] & b[i]`. | |||
|
|||
* <code>V **MaskedOrOrZero**(M m, V a, V b)</code>: returns `a[i] || b[i]` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about a different naming convention here which might be a bit more natural?
There is also a MaskedLoad which returns 0 as the default, as opposed to MaskedLoadOr, which has the explicit default value. If we apply that here, we can just call it MaskedOr(m, a b), what do you think?
@@ -1050,6 +1050,9 @@ types, and on SVE/RVV. | |||
|
|||
* <code>V **AndNot**(V a, V b)</code>: returns `~a[i] & b[i]`. | |||
|
|||
* <code>V **MaskedOrOrZero**(M m, V a, V b)</code>: returns `a[i] || b[i]` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we mean a[i] | b[i]
?
@@ -2237,6 +2240,22 @@ The following `ReverseN` must not be called if `Lanes(D()) < N`: | |||
must be in the range `[0, 2 * Lanes(d))` but need not be unique. The index | |||
type `TI` must be an integer of the same size as `TFromD<D>`. | |||
* <code>V **TableLookupLanesOr**(M m, V a, V b, unspecified)</code> returns the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like we don't yet have an optimized version of these op, and it's just a convenience wrapper over IfThenElse. Would it be an option to move this into a utility function within your codebase? It's not clear whether this provides enough value to be a documented op that all readers must know.
IfThenElseZero(m, v)))` etc. The result is implementation-defined when all mask | ||
elements are false. | ||
* <code>T **MaskedReduceSum**(D, M m, V v)</code>: returns the sum of all lanes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! This looks useful.
Please add a TODO that we should also implement this for RVV.
#define HWY_NATIVE_MASKED_REDUCE_SCALAR | ||
#endif | ||
|
||
template <class D, class M> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a TODO here that we can remove the SumOfLanesM in favor of using MaskedReduceSum directly. This entails adding the D arg to HWY_SVE_REDUCE_ADD
as done in HWY_SVE_FIRSTN
.
@@ -4755,6 +4804,23 @@ HWY_API V IfNegativeThenElse(V v, V yes, V no) { | |||
static_assert(IsSigned<TFromV<V>>(), "Only works for signed/float"); | |||
return IfThenElse(IsNegative(v), yes, no); | |||
} | |||
// ------------------------------ IfNegativeThenNegOrUndefIfZero |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This op is undocumented, do we intend to add it? If so, let's add documentation and test.
@@ -219,6 +219,15 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _) | |||
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ | |||
return sv##OP##_##CHAR##BITS(v); \ | |||
} | |||
#define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP) \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor: we have the naming convention P for predicate, for example in HWY_SVE_RETV_ARGPVV
. I'm fine with either P or M, but let's please be consistent, feel free to pick one.
This might actually replace the existing HWY_SVE_RETV_ARGPV
.
} | ||
template <class D, class M> | ||
HWY_API TFromD<D> MaskedReduceMin(D d, M m, VFromD<D> v) { | ||
return ReduceMin(d, IfThenElse(m, v, MaxOfLanes(d, v))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems unnecessarily expensive, how about we replace MaxOfLanes with Set(d, hwy::HighestValue)?
} | ||
template <class D, class M> | ||
HWY_API TFromD<D> MaskedReduceMax(D d, M m, VFromD<D> v) { | ||
return ReduceMax(d, IfThenElseZero(m, v)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can get into trouble for signed values. If all values are negative, the presence of mask=false elements changes the result. Can similarly use hwy::LowestValue here?
Introduces:
a[i] || b[i]
orzero
ifm[i]
is false.TwoTablesLookupLanes(V a, V b, unspecified)
wherem[i]
is true, anda[i]
wherem[i]
is false.TwoTablesLookupLanes(V a, V b, unspecified)
wherem[i]
is true, and zero wherem[i]
is false.m[i]
istrue
.m[i]
istrue
.m[i]
istrue
.mask[i] < 0 ? (-v[i]) : ((mask[i] > 0) ? v[i] : impl_defined_val)
, whereimpl_defined_val
is an implementation-defined value that is equal to either 0 orv[i]
. SVE included only.Testing is performed for all new operations.